{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "from tokenizers import decoders, models, pre_tokenizers, trainers, Tokenizer\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training corpus\n",
    "def read_data(path):\n",
    "    \"\"\"Lazily yield the ``text`` field of each record in a JSON Lines file.\n",
    "\n",
    "    Args:\n",
    "        path: Path to a UTF-8 encoded file holding one JSON object per line,\n",
    "            each with a ``text`` key.\n",
    "\n",
    "    Yields:\n",
    "        The ``text`` value of every non-blank line, in file order.\n",
    "\n",
    "    Raises:\n",
    "        json.JSONDecodeError: If a non-blank line is not valid JSON.\n",
    "        KeyError: If a record has no ``text`` key.\n",
    "    \"\"\"\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            # Blank or whitespace-only lines carry no record; passing them to\n",
    "            # json.loads would raise and abort the whole iteration.\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            yield json.loads(line)['text']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build an untrained BPE tokenizer with a byte-level pre-tokenizer\n",
    "bpe_model = models.BPE()\n",
    "tokenizer = Tokenizer(bpe_model)\n",
    "\n",
    "# add_prefix_space=False: do not prepend a space to the input before splitting\n",
    "byte_level = pre_tokenizers.ByteLevel(add_prefix_space=False)\n",
    "tokenizer.pre_tokenizer = byte_level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Special tokens handed to the trainer, in the order they are registered\n",
    "special_tokens = [\n",
    "    '<pad>',\n",
    "    '<unk>',\n",
    "    '<s>',\n",
    "    '</s>',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the BPE trainer\n",
    "# Seed the vocabulary with the ByteLevel alphabet in addition to the special tokens\n",
    "byte_alphabet = pre_tokenizers.ByteLevel.alphabet()\n",
    "\n",
    "trainer = trainers.BpeTrainer(\n",
    "    vocab_size=6400,\n",
    "    show_progress=True,\n",
    "    special_tokens=special_tokens,\n",
    "    initial_alphabet=byte_alphabet,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stream the corpus lazily. read_data returns a one-shot generator, so this cell\n",
    "# has to be re-run before the training cell is run again.\n",
    "train_path = './dataset/tokenizer_train.jsonl'\n",
    "texts = read_data(train_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Fit the BPE vocabulary and merges on the streamed texts\n",
    "tokenizer.train_from_iterator(texts, trainer=trainer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the ByteLevel decoder that pairs with the ByteLevel pre-tokenizer\n",
    "byte_level_decoder = decoders.ByteLevel()\n",
    "tokenizer.decoder = byte_level_decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained tokenizer as a single tokenizer.json file\n",
    "tokenizer_dir = \"./tokenizer\"\n",
    "tokenizer_json_path = os.path.join(tokenizer_dir, \"tokenizer.json\")\n",
    "\n",
    "os.makedirs(tokenizer_dir, exist_ok=True)\n",
    "tokenizer.save(tokenizer_json_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./tokenizer/vocab.json', './tokenizer/merges.txt']"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Also export the underlying model files; the returned paths are shown as output\n",
    "model_files = tokenizer.model.save(tokenizer_dir)\n",
    "model_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flags shared by every special-token entry in the config\n",
    "special_token_flags = {\n",
    "    \"lstrip\": False,\n",
    "    \"normalized\": False,\n",
    "    \"rstrip\": False,\n",
    "    \"single_word\": False,\n",
    "    \"special\": True,\n",
    "}\n",
    "\n",
    "# Contents of tokenizer_config.json, written next to tokenizer.json\n",
    "config = {\n",
    "    \"add_bos_token\": False,\n",
    "    \"add_eos_token\": False,\n",
    "    # Keep in sync with ByteLevel(add_prefix_space=False) used for training.\n",
    "    # NOTE(review): some transformers releases appear to apply this flag to the\n",
    "    # loaded pre-tokenizer, so a mismatch could change tokenization - confirm.\n",
    "    \"add_prefix_space\": False,\n",
    "    # Derived from special_tokens so ids follow the order given to the trainer\n",
    "    # ('<pad>', '<unk>', '<s>', '</s>' -> 0..3) rather than a hand-copied table.\n",
    "    \"added_tokens_decoder\": {\n",
    "        str(token_id): {\"content\": token, **special_token_flags}\n",
    "        for token_id, token in enumerate(special_tokens)\n",
    "    },\n",
    "    \"additional_special_tokens\": [],\n",
    "    \"bos_token\": \"<s>\",\n",
    "    \"clean_up_tokenization_spaces\": False,\n",
    "    \"eos_token\": \"</s>\",\n",
    "    \"legacy\": True,\n",
    "    \"model_max_length\": 100000,\n",
    "    # '<pad>' is one of the trained special tokens, so expose it as the pad token\n",
    "    \"pad_token\": \"<pad>\",\n",
    "    \"sp_model_kwargs\": {},\n",
    "    \"spaces_between_special_tokens\": False,\n",
    "    \"tokenizer_class\": \"PreTrainedTokenizerFast\",\n",
    "    \"unk_token\": \"<unk>\",\n",
    "    \"use_default_system_prompt\": False,\n",
    "    \"chat_template\": \"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\\\n' + content + '</s>\\\\n<s>assistant\\\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\\\n' }}{% endif %}{% endfor %}\",\n",
    "}\n",
    "\n",
    "# Save the config file\n",
    "config_path = os.path.join(tokenizer_dir, \"tokenizer_config.json\")\n",
    "with open(config_path, \"w\", encoding=\"utf-8\") as config_file:\n",
    "    json.dump(config, config_file, ensure_ascii=False, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[804, 588]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: reload the saved files through transformers and encode a sample.\n",
    "# From here on `tokenizer` is the transformers object loaded from disk, not the\n",
    "# tokenizers.Tokenizer trained above.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
    "tokenizer.encode(\"您好\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([0], [1], [2], [3])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Encode each special token on its own and show the resulting id lists\n",
    "tuple(tokenizer.encode(token) for token in (\"<pad>\", \"<unk>\", \"<s>\", \"</s>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'好'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode a single id (the second id produced for \"您好\" above) back to text\n",
    "sample_id = 588\n",
    "tokenizer.decode(sample_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6400"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary size reported by the reloaded tokenizer (trainer used vocab_size=6400)\n",
    "vocab_size = tokenizer.vocab_size\n",
    "vocab_size"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wyf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
